

# This code is an example of how the original models, which stand in for human participants in the simulation experiments, were trained.
# To run this file, you need additional image data files (.pkl) that do not overlap with CXR-A and CXR-B.
# Instead of providing these files, we provide the trained model files produced by running this code.
# Please use this code file for reference only.


import copy
import pickle
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import numpy as np
from torch.utils.data import DataLoader, Dataset


# All models and batches in this script are placed on the CPU.
device = torch.device(type="cpu")


class CNN1(nn.Module):
    """Small two-conv-layer CNN producing 2 class logits for single-channel images.

    ``conv1`` takes 1 input channel. Both convolutions preserve spatial size
    (3x3, padding 1), and each 2x2 max-pool halves it. ``fc1`` is sized for
    32 x 16 x 16, so the expected input is (batch, 1, 64, 64).
    """

    # Flattened size after conv2 + second pool: 32 channels x 16 x 16.
    _FLAT_FEATURES = 32 * 16 * 16

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        self.fc1 = nn.Linear(self._FLAT_FEATURES, 64)
        self.fc2 = nn.Linear(64, 2)
        # Move parameters to the module-level ``device`` at construction time.
        self.to(device)

    def forward(self, x):
        """Run the network on ``x``.

        Returns:
            A tuple ``(logits, features)``:

            - ``logits``: output of ``fc2``, shape (batch, 2).
            - ``features``: the 64-dim ``fc1`` output *before* the ReLU.
        """
        x = self.pool(torch.relu(self.conv1(x)))
        x = self.pool(torch.relu(self.conv2(x)))
        # flatten keeps the batch dimension intact. A wrongly sized input then
        # fails at fc1 rather than being silently regrouped by view(-1, ...).
        x = torch.flatten(x, start_dim=1)
        features = self.fc1(x)
        logits = self.fc2(torch.relu(features))
        return logits, features



class CustomDataset(Dataset):
    """Map-style dataset that pairs ``images[idx]`` with ``targets[idx]``.

    When a truthy ``transform`` is given, it is applied to the image only.
    The target is returned untouched.
    """

    def __init__(self, images, targets, transform=None):
        self.images, self.targets, self.transform = images, targets, transform

    def __len__(self):
        # The dataset length follows the image collection.
        return len(self.images)

    def __getitem__(self, idx):
        sample = self.images[idx]
        label = self.targets[idx]
        if not self.transform:
            return sample, label
        return self.transform(sample), label


class CustomDataset_q(Dataset):
    """Map-style dataset over ``images`` alone (no targets).

    When a truthy ``transform`` is given, each image is passed through it
    before being returned.
    """

    def __init__(self, images,  transform=None):
        self.images, self.transform = images, transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        sample = self.images[idx]
        return self.transform(sample) if self.transform else sample

def _val_accuracy(model, val_loader):
    """Return the percentage (0-100) of samples in ``val_loader`` classified correctly.

    Labels are expected to be one-hot rows. A prediction counts as correct
    when the label is 1 at the predicted class index. The model is put in eval
    mode for the pass, and its previous train/eval mode is restored afterwards.
    Returns 0.0 for an empty loader.
    """
    was_training = model.training
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in val_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            # The model returns (logits, features); only the logits are needed.
            logits, _ = model(inputs)
            predicted = logits.argmax(dim=1)
            correct += (predicted == labels.argmax(dim=1)).sum().item()
            total += labels.size(0)
    model.train(was_training)
    return 100 * correct / total if total else 0.0


def _build_train_loader(dataA, dataB, transform, batch_size=100):
    """Build a shuffled DataLoader: values of ``dataA`` -> class 0, ``dataB`` -> class 1."""
    images_a = list(dataA.values())
    images_b = list(dataB.values())
    # One-hot float targets; CrossEntropyLoss accepts class-probability targets.
    targets = [[1, 0]] * len(images_a) + [[0, 1]] * len(images_b)
    # np.array stacks the images, so all of them must share one shape.
    images_train = torch.tensor(np.array(images_a + images_b))
    targets_train = torch.tensor(np.array(targets), dtype=torch.float)
    dataset = CustomDataset(images_train, targets_train, transform=transform)
    return DataLoader(dataset, batch_size=batch_size, shuffle=True)


def training_cnn(model, dataA, dataB, iteration, val_loader=None):
    """Train ``model`` to separate the images in ``dataA`` (class 0) from ``dataB`` (class 1).

    Args:
        model: network whose forward returns ``(logits, features)`` with
            2 logits, e.g. ``CNN1``.
        dataA, dataB: mappings whose *values* are images. They are presumably
            2-D grayscale arrays (CNN1 has a single input channel); confirm
            against the .pkl files.
        iteration: number of training epochs.
        val_loader: optional DataLoader yielding ``(image, one-hot label)``
            batches. If given, the weights from the epoch with the highest
            validation accuracy are returned.

    Returns:
        The trained model. This is the final-epoch model when ``val_loader``
        is None, otherwise a copy of the best-validation-epoch model.
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    # Resize to 64x64: after CNN1's two 2x2 pools this yields the 16x16 maps fc1 expects.
    transform = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize((64, 64)),
        transforms.ToTensor()
    ])
    train_loader = _build_train_loader(dataA, dataB, transform)

    # Without validation data the final trained model is returned. The original
    # returned a deepcopy taken *before* training, i.e. untrained weights.
    best_model = model
    max_val_acc = -1.0
    model.train()
    for epoch in range(iteration):
        for i, (inputs, labels) in enumerate(train_loader):
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs, _ = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            print('[%d, %5d] loss: %.3f' % (epoch + 1, i + 1, loss.item()))

        if val_loader is not None:
            val_acc = _val_accuracy(model, val_loader)
            print('[%d] val acc: %.2f%%' % (epoch + 1, val_acc))
            if val_acc > max_val_acc:
                max_val_acc = val_acc
                best_model = copy.deepcopy(model)

    return best_model



# Train and save one CNN1 per data split.
# Each A_data{k}.pkl / B_data{k}.pkl is read as two consecutive pickled objects,
# passed to training_cnn as the first and second class respectively.
# NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
for i in range(8):
    with open(f"A_data{i + 1}.pkl", 'rb') as file:
        normal = pickle.load(file)
        abnormal = pickle.load(file)
    # Train after the with-block so the file is not held open during training.
    model = CNN1()
    model = training_cnn(model, normal, abnormal, 100)
    torch.save(model.state_dict(), f'modelA_{i + 1}.pth')

for i in range(8):
    with open(f"B_data{i + 1}.pkl", 'rb') as file:
        edema = pickle.load(file)
        pneumonia = pickle.load(file)
    model = CNN1()
    model = training_cnn(model, edema, pneumonia, 100)
    torch.save(model.state_dict(), f'modelB_{i + 1}.pth')

